batch_size = 8
n_epochs = 8000
Experiment 3: Learning to regenerate
In the third experiment we train the automaton to regenerate its form after it has been corrupted, for example when part of it is deleted or a hole is punched into it.
path = '../images/emoji_u1f98e.png'
img_tensor = load_image(path)
Pool Training with holes
We want the sample pool to produce corrupted samples as well, that is, samples containing random holes.
We start by creating a function that receives a batch of images and creates a random hole in each one of them.
create_hole
create_hole (batch)
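Only the signature of `create_hole` is listed here, so its exact behavior is not shown. Below is a minimal sketch of one way it could work, assuming circular holes with a randomly drawn center and radius, similar to the damage masks in the Growing Neural Cellular Automata article. `make_circle_masks`, `create_hole_sketch`, the `generator` argument, and the center/radius ranges are illustrative assumptions, not the notebook's actual code.

```python
import torch


def make_circle_masks(n, h, w, generator=None):
    """Hypothetical helper: (n, h, w) mask that is 0 inside one random circle per sample, 1 elsewhere."""
    # Pixel coordinates normalized to [-1, 1]; indexing="ij" yields (row, column) order
    ys, xs = torch.meshgrid(
        torch.linspace(-1.0, 1.0, h), torch.linspace(-1.0, 1.0, w), indexing="ij"
    )
    # Assumed ranges: one center in [-0.5, 0.5]^2 and one radius in [0.1, 0.4] per sample
    center = torch.rand(n, 2, 1, 1, generator=generator) - 0.5
    radius = 0.3 * torch.rand(n, 1, 1, generator=generator) + 0.1
    dist = torch.sqrt((ys - center[:, 0]) ** 2 + (xs - center[:, 1]) ** 2)
    return (dist >= radius).float()


def create_hole_sketch(batch, generator=None):
    """Hypothetical stand-in for create_hole: zero a random circle in each (B, C, H, W) image."""
    b, _, h, w = batch.shape
    mask = make_circle_masks(b, h, w, generator).to(batch.device, batch.dtype)
    # Broadcast the (B, 1, H, W) mask over every channel, so all channels share the hole
    return batch * mask[:, None]
```

Passing `indexing="ij"` to `torch.meshgrid` makes the (row, column) ordering explicit and avoids PyTorch's warning about the default indexing changing in a future release.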
corrupted_image = create_hole(img_tensor)
plt.imshow(corrupted_image[0].detach().cpu().permute(1, 2, 0))
plt.show()
CorruptedPool
CorruptedPool (pool_size=1024, loss_fn=None, device='cpu')
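The source of `CorruptedPool` is not shown. The sketch below illustrates one way such a pool could work, modelled on the sample-pool strategy from the Growing Neural Cellular Automata article: sort the sampled batch by loss, reset the worst sample to the seed, and damage the best-performing samples. `DamagedPoolSketch`, its `seed`, `damage_fn`, `n_damage` and `generator` parameters, and the assumption that `loss_fn` returns one loss value per sample are all hypothetical, not the actual `CorruptedPool` API.

```python
import torch


class DamagedPoolSketch:
    """Hypothetical sample pool that damages its best samples when sampling."""

    def __init__(self, seed, pool_size=1024, loss_fn=None, damage_fn=None, generator=None):
        # seed: (1, C, H, W) initial state; every pool slot starts as a copy of it
        self.seed = seed
        self.pool = seed.repeat(pool_size, 1, 1, 1)
        self.loss_fn = loss_fn      # assumed: maps a batch to one loss per sample
        self.damage_fn = damage_fn  # assumed: maps a batch to a damaged batch
        self.generator = generator
        self.idxs = None

    def sample_with_damage(self, batch_size=8, n_damage=3):
        # Draw distinct random slots from the pool
        perm = torch.randperm(len(self.pool), generator=self.generator)
        self.idxs = perm[:batch_size]
        batch = self.pool[self.idxs].clone()
        if self.loss_fn is not None:
            # Sort the batch so the highest-loss sample comes first
            order = torch.argsort(self.loss_fn(batch), descending=True)
            self.idxs, batch = self.idxs[order], batch[order]
        # Reset the first (worst, when sorted) sample to the seed so growth keeps being trained
        batch[0] = self.seed[0]
        # Damage the last n_damage samples, i.e. the lowest-loss ones when sorted
        if self.damage_fn is not None and n_damage > 0:
            batch[-n_damage:] = self.damage_fn(batch[-n_damage:])
        return batch

    def update(self, res):
        # Write the CA outputs back into the slots they were drawn from
        self.pool[self.idxs] = res.detach()
```

With this design, passing the notebook's `create_hole` as `damage_fn` would plug in its own hole generator, assuming it returns the damaged batch.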
loss_fn = partial(mse, target=img_tensor.repeat(batch_size, 1, 1, 1).to(def_device))
pool = CorruptedPool(1024, loss_fn=loss_fn)
batch = pool.sample_with_damage()
vis_batch(batch)
Training Loop
# Instantiate the model
ca = CAModel(CHANNEL_N).to(def_device)
# Optimization
lr = 2e-3
lr_gamma = 0.9999
betas = (0.5, 0.5)
optimizer = torch.optim.Adam(ca.parameters(), lr=lr, betas=betas)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, lr_gamma)
target = img_tensor.repeat(batch_size, 1, 1, 1).to(def_device)
for i in tqdm(range(n_epochs)):
    # reset the gradients accumulated in the previous iteration
    optimizer.zero_grad()
    # pick a random number of CA steps in [64, 96)
    steps = torch.randint(64, 96, (1,)).item()
    # sample the pool (with some samples damaged) to get the input
    model_in = pool.sample_with_damage()
    # run the CA for `steps` update steps
    res = ca(model_in, steps=steps)
    # calculate the loss
    loss = F.mse_loss(res[:, :4], target)  # we only care about the RGBA channels
    # write the results back into the pool
    pool.update(res)
    # log the loss
    if i % 500 == 0:
        print(f"Epoch: {i} Loss: {loss.item()}")
    # backpropagate the loss and update the weights
    loss.backward()
    optimizer.step()
    scheduler.step()
Epoch: 0 Loss: 0.11324034631252289
Epoch: 500 Loss: 0.029659921303391457
Epoch: 1000 Loss: 0.01154416985809803
Epoch: 1500 Loss: 0.009176425635814667
Epoch: 2000 Loss: 0.004854139406234026
Epoch: 2500 Loss: 0.0035146409645676613
Epoch: 3000 Loss: 0.0019474619766697288
Epoch: 3500 Loss: 0.0017854450270533562
Epoch: 4000 Loss: 0.0013569763395935297
Epoch: 4500 Loss: 0.0009505663765594363
Epoch: 5000 Loss: 0.0013342727907001972
Epoch: 5500 Loss: 0.001305947545915842
Epoch: 6000 Loss: 0.00037199087091721594
Epoch: 6500 Loss: 0.0005692397826351225
Epoch: 7000 Loss: 0.00025373895186930895
Epoch: 7500 Loss: 0.0003689015575218946
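The learning rate is not constant over these 8000 iterations: `ExponentialLR` multiplies it by `lr_gamma` after every call to `scheduler.step()`, so it ends at 2e-3 × 0.9999^8000 ≈ 9.0e-4, roughly 45% of the starting value. The short check below rebuilds the same optimizer/scheduler pairing on a single dummy parameter (`param`, `final_lr` and `closed_form` are illustrative names) and compares the scheduler's final value with the closed form.

```python
import torch

lr, lr_gamma, n_epochs = 2e-3, 0.9999, 8000

# A single dummy parameter is enough to drive the optimizer/scheduler pair
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=lr, betas=(0.5, 0.5))
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, lr_gamma)

for _ in range(n_epochs):
    optimizer.step()  # no gradients are set, so Adam leaves `param` unchanged
    scheduler.step()  # multiplies the learning rate by lr_gamma

final_lr = scheduler.get_last_lr()[0]
closed_form = lr * lr_gamma ** n_epochs
print(f"scheduler: {final_lr:.4e}, closed form: {closed_form:.4e}")
```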
images = ca.grow_animation(seed, 200)
display_animation(images)
The automaton maintains its shape as time passes, even over 200 steps, well beyond the 64 to 96 steps used during training.
Visualize a batch from the updated pool
During training, the samples inside the pool are continually replaced by the model's outputs. Let's visualize what a batch of samples looks like once training is complete.
batch = pool.sample_with_damage()
vis_batch(batch)

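Test the reconstruction ability of the model
# Grow a fully formed automaton, which we will then corrupt
generated = ca(seed, steps=96).detach()
Corruption 1: zero all channels in columns 20 and beyond, removing the right-hand part of the pattern.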
corrupted_input1 = generated.clone()
corrupted_input1[..., 20:] = 0
plt.imshow(corrupted_input1[0].permute(1, 2, 0)[:, :, :4].cpu())

images = ca.grow_animation(corrupted_input1, 300)
display_animation(images)
Corruption 2: zero all channels in the first 20 rows, removing the top part of the pattern.
corrupted_input2 = generated.clone()
corrupted_input2[..., :20, :] = 0
plt.imshow(corrupted_input2[0].permute(1, 2, 0)[:, :, :4].cpu())

images = ca.grow_animation(corrupted_input2, 300)
display_animation(images)
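The animations show regeneration qualitatively. A minimal sketch for putting a number on it, assuming `ca(x, steps=1)` advances the automaton by a single update: record the RGBA mean squared error against the target image after each step. `reconstruction_curve` and its `step_fn` argument are hypothetical names, not part of the notebook.

```python
import torch


def reconstruction_curve(step_fn, state, target, n_steps):
    """Hypothetical helper: RGBA MSE between the CA state and the target after each step."""
    losses = []
    with torch.no_grad():
        for _ in range(n_steps):
            state = step_fn(state)
            # Compare only the first four (RGBA) channels, as the training loss does
            losses.append(torch.mean((state[:, :4] - target) ** 2).item())
    return losses
```

In the notebook this might be called as `reconstruction_curve(lambda x: ca(x, steps=1), corrupted_input1, img_tensor.to(def_device), 300)`; a curve that falls back toward the loss level reached at the end of training would indicate that the damaged region has been regrown.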